{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import IPython.display as ipd\n",
    "\n",
    "import os\n",
    "import json\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import commons\n",
    "import utils\n",
    "from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate\n",
    "from models import SynthesizerTrn\n",
    "from text.symbols import symbols\n",
    "from text import text_to_sequence\n",
    "\n",
    "from scipy.io.wavfile import write\n",
    "\n",
    "\n",
    "def get_text(text, hps):\n",
    "    text_norm = text_to_sequence(text, hps.data.text_cleaners)\n",
    "    if hps.data.add_blank:\n",
    "        text_norm = commons.intersperse(text_norm, 0)\n",
    "    text_norm = torch.LongTensor(text_norm)\n",
    "    return text_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "hps = utils.get_hparams_from_file(\"./configs/ms_ms.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mutli-stream iSTFT VITS\n"
     ]
    }
   ],
   "source": [
    "net_g = SynthesizerTrn(\n",
    "    \n",
    "    len(symbols),\n",
    "    hps.data.filter_length // 2 + 1,\n",
    "    hps.train.segment_size // hps.data.hop_length,\n",
    "    n_speakers=hps.data.n_speakers,\n",
    "    **hps.model)\n",
    "_ = net_g.eval()\n",
    "\n",
    "_ = utils.load_checkpoint(\"./logs/ms_ms/G_18000.pth\", net_g, None)\n",
    "# _ = utils.load_checkpoint(\"logs/ljs_mb_istft_vits/G_1000000.pth\", net_g, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def tts(txt,device=\"cuda\"):\n",
    "    \n",
    "    stn_tst = get_text(txt, hps)\n",
    "    with torch.no_grad():\n",
    "        x_tst = stn_tst.unsqueeze(0)\n",
    "        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n",
    "        spk = torch.LongTensor([3])\n",
    "        t1 = time.time()\n",
    "        audio = net_g.to(device).infer(x_tst.to(device), x_tst_lengths.to(device),sid=spk, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].cpu().data.float().numpy()\n",
    "        t2 = time.time()\n",
    "        print(\"推理时间：\", (t2-t1),\"s\")\n",
    "    ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate))\n",
    "# tts(\"[ZH]大家好啊，我是说的道理，今天来点大家想看的东西[ZH]\",device=\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tts(\"[JA]なんでこんなに慣れてんのよ。私の方が先に好きだったのに[JA]\",device=\"cuda\")\n",
    "# tts(\"[JA]こんにちは。私わあやちねねです。こんにちは。私わあやちねねです。こんにちは。私わあやちねねです。[JA]\",device=\"cuda\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tts(\"[JA]なんでこんなに慣れてんのよ。私の方が先に好きだったのに[JA]\",device=\"cpu\")\n",
    "# tts(\"[JA]こんにちは。私わあやちねねです。こんにちは。私わあやちねねです。こんにちは。私わあやちねねです。[JA]\",device=\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tts(\"[JA]そうだ。俺たちが今まで积み上げてきたもんは全部无駄じゃなかった。こからも俺たちが立ち止まらないかぎり道は続く[JA]\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
